{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from minepy import MINE\n",
    "from joblib  import Parallel,delayed\n",
    "import numpy as np\n",
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mic(x,y):\n",
    "    '''Return the MIC score (mic_e estimator) between x and y.\n",
    "\n",
    "    Each argument may be a pandas DataFrame, in which case its `prob`\n",
    "    column is flattened to a 1-D array, or a value that is passed to\n",
    "    MINE.compute_score unchanged.\n",
    "    '''\n",
    "    # Convert each argument on its own merits: checking only x and then\n",
    "    # assuming y is a DataFrame too breaks on mixed inputs, and an exact\n",
    "    # type comparison rejects DataFrame subclasses.\n",
    "    if isinstance(x, pd.DataFrame):\n",
    "        x = x.prob.values.ravel()\n",
    "    if isinstance(y, pd.DataFrame):\n",
    "        y = y.prob.values.ravel()\n",
    "    m = MINE(est='mic_e')\n",
    "    m.compute_score(x, y)\n",
    "    return m.mic()\n",
    "\n",
    "def cal_raw(x,pred_list):\n",
    "    '''Compute mic(x, y) for every y in pred_list using joblib workers.\n",
    "\n",
    "    Returns a list with one score per element of pred_list; an empty\n",
    "    pred_list yields an empty list.\n",
    "    '''\n",
    "    if len(pred_list) == 0:\n",
    "        # joblib rejects n_jobs=0, so return early instead of crashing.\n",
    "        return []\n",
    "    # At most one worker per comparison, and never more workers than CPUs.\n",
    "    n_jobs = min(len(pred_list), os.cpu_count() or 1)\n",
    "    result = Parallel(n_jobs=n_jobs, verbose=10)(delayed(mic)(x, y) for y in pred_list)\n",
    "    return result\n",
    "\n",
    "def cal_matrix(pred_list):\n",
    "    '''Build the pairwise MIC matrix for a list of predictions.\n",
    "\n",
    "    Row i is cal_raw(pred_list[i], pred_list), so entry (i, j) is the\n",
    "    score between prediction i and prediction j. Returns a numpy array;\n",
    "    an empty pred_list yields an empty (0, 0) array.\n",
    "    '''\n",
    "    if len(pred_list) == 0:\n",
    "        # joblib rejects n_jobs=0, so return early instead of crashing.\n",
    "        return np.empty((0, 0))\n",
    "    # At most one worker per row, and never more workers than CPUs.\n",
    "    n_jobs = min(len(pred_list), os.cpu_count() or 1)\n",
    "    rows = Parallel(n_jobs=n_jobs, verbose=10)(delayed(cal_raw)(x, pred_list) for x in pred_list)\n",
    "    return np.array(rows)\n",
    "\n",
    "def plot_mic_matrix(mic_matrix,ticks):\n",
    "    '''Draw mic_matrix as an annotated square heatmap.\n",
    "\n",
    "    ticks supplies the labels used on both the x and y axes.\n",
    "    '''\n",
    "    plt.figure(figsize=(16, 16))\n",
    "    # Apply the tick font size to the x axis first, then the y axis.\n",
    "    for set_ticks in (plt.xticks, plt.yticks):\n",
    "        set_ticks(fontsize=20)\n",
    "    heatmap_options = dict(\n",
    "        linewidths=1,\n",
    "        vmax=1.0,\n",
    "        square=True,\n",
    "        linecolor='white',\n",
    "        annot=True,\n",
    "        xticklabels=ticks,\n",
    "        yticklabels=ticks,\n",
    "    )\n",
    "    sns.heatmap(mic_matrix, **heatmap_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory the prediction CSV files are read from.\n",
    "PRED_DIR = './preds_for_bagging/'\n",
    "\n",
    "# os.listdir returns entries in arbitrary order, so sort the names to keep\n",
    "# the matrix rows/columns in a reproducible order. Names starting with a\n",
    "# dot (hidden files) are skipped.\n",
    "file_names = sorted(name for name in os.listdir(PRED_DIR) if not name.startswith('.'))\n",
    "\n",
    "# One DataFrame per file, in the same order as file_names.\n",
    "pred_list = [pd.read_csv(os.path.join(PRED_DIR, name)) for name in file_names]\n",
    "mic_matrix = cal_matrix(pred_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the pairwise MIC scores, labelled with the source file names.\n",
    "plot_mic_matrix(mic_matrix=mic_matrix, ticks=file_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
